{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5bdfa88d-1166-43b0-853d-b62009979e4c",
   "metadata": {},
   "source": [
    "# Codex on Penguins Dataset Original Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "a3b46129-4458-4f10-b4c7-e220a32e43f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import re\n",
    "import time\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from pprint import pprint\n",
    "from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "6e9a02a6-b084-4df3-b26c-50fd16758ca5",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai.api_key = \"sk-\"  # hao's key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "6a692864-00dd-4784-8d27-4af9dbe7b3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] +\n",
    "                       [wait_fixed(5) for i in range(2)] +\n",
    "                       [wait_fixed(10)]))\n",
    "def completion_with_backoff(**kwargs):\n",
    "    return openai.ChatCompletion.create(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "b5690ce3-da4d-4825-8180-255b8759ba90",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = json.load(open('data/penguins_in_a_table.json'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "f1ac22ca-f51c-4d7b-9c8c-8d13ae6dcebe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_ans(ans_model):\n",
    "    ans_model = ans_model.split('\\n')\n",
    "    ans = []\n",
    "    residual = []\n",
    "    for li, al in enumerate(ans_model):\n",
    "        ans.append(al)\n",
    "        if('answer is' in al):\n",
    "            break\n",
    "    residual = list(ans_model[li + 1:])\n",
    "    ans = '\\n'.join(ans)\n",
    "    residual = '\\n'.join(residual)\n",
    "    return ans, residual"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6502dfdd-78e0-4b6c-a1d3-cb38ae861ff3",
   "metadata": {},
   "source": [
    "# Original Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "7292498a-5943-4216-a986-66687177b4b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_original = open('lib_prompt/prompt_original.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "45aa1e63-e4b2-4c6a-876c-3e00d62fbcd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 146/146 [12:20<00:00,  5.07s/it]\n"
     ]
    }
   ],
   "source": [
    "acc = 0\n",
    "total = 0\n",
    "with open('outputs/test_codex_original.txt', 'w') as fd:\n",
    "    for d in tqdm(data['examples']):\n",
    "        q = d['input']\n",
    "        a = d['target'][1]\n",
    "        \n",
    "        prompt_q = prompt_original + '\\nQ: ' + q + '\\n'\n",
    "        prompt_q += \"A: Let's think step by step.\"\n",
    "\n",
    "        response = completion_with_backoff(\n",
    "            model=\"gpt-3.5-turbo\",\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": \"Follow the given examples and answer the question.\"},\n",
    "                {\"role\": \"user\", \"content\": prompt_q},\n",
    "            ],\n",
    "            temperature=0, \n",
    "        )\n",
    "        \n",
    "        ans_model = response['choices'][0]['message']['content']\n",
    "\n",
    "        ans, residual = extract_ans(ans_model)\n",
    "\n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_model, a))\n",
    "        \n",
    "        ans_ = ans_model.split('answer is ')\n",
    "        if(len(ans_) > 1): ans_ = ans_[1][1]\n",
    "        else: ans_ = 'A'\n",
    "\n",
    "        if(ans_ == a): acc += 1\n",
    "        total += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "4061fdd0-5f8e-40fc-9c27-1a7ee627405a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total 146 correct 115 acc 0.7877\n"
     ]
    }
   ],
   "source": [
    "print('Total %d correct %d acc %.4f' % (total, acc, acc / total))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "394d3d21f71232a2e1ff4fbe70d3f8f4c4fb55ac495a47446cbd463fb339819b"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
